import torch
import torch.nn as nn


def channel_shuffle(x, groups):
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel axis (dim 1) is viewed as ``groups`` consecutive chunks of
    ``num_channels // groups`` channels each. The group axis and the
    within-group axis are then swapped and flattened back into one channel
    axis, so the output takes one channel from every chunk in turn.

    Args:
        x: Tensor with exactly four dimensions, unpacked here as
            (batch, channels, height, width).
        groups: Number of chunks the channel axis is split into.

    Returns:
        Tensor with the same batch, height and width as ``x`` and the
        channels reordered as described above.
    """
    n, c, h, w = x.shape
    per_group = c // groups

    # (n, c, h, w) -> (n, groups, per_group, h, w)
    grouped = x.reshape(n, groups, per_group, h, w)
    # Swap the group axis with the within-group axis.
    swapped = grouped.permute(0, 2, 1, 3, 4)
    # Collapse the two channel axes back into one.
    return swapped.reshape(n, -1, h, w)


def depthwise_conv(input_c, output_c, kernel_s, stride=1, padding=0, bias=False):
    """Create a 2-D convolution with one group per input channel.

    Args:
        input_c: Number of input channels; also used as the ``groups`` value.
        output_c: Number of output channels.
        kernel_s: Kernel size forwarded to ``nn.Conv2d``.
        stride: Convolution stride (default 1).
        padding: Padding forwarded to ``nn.Conv2d`` (default 0).
        bias: Whether the layer has a learnable bias (default False).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = {
        "stride": stride,
        "padding": padding,
        "groups": input_c,  # one group per input channel
        "bias": bias,
    }
    return nn.Conv2d(input_c, output_c, kernel_s, **conv_options)


class InvertedResidual(nn.Module):
    """Two-branch residual unit followed by a channel shuffle.

    With ``stride == 1`` the input is split in half along the channel axis:
    the first half is passed through untouched and the second half goes
    through ``branch2``. With ``stride == 2`` the full input is fed to both
    ``branch1`` and ``branch2``, each of which downsamples with a stride-2
    depthwise 3x3 convolution. In both cases the two results are
    concatenated along the channel axis (giving ``output_c`` channels) and
    shuffled with ``channel_shuffle(out, 2)``.

    Args:
        input_c: Number of input channels.
        output_c: Number of output channels; must be even because each
            branch contributes ``output_c // 2`` channels.
        stride: 1 (keep resolution) or 2 (downsample).

    Raises:
        ValueError: If ``stride`` is not 1 or 2, if ``output_c`` is odd, or
            if ``stride == 1`` and ``input_c != output_c``.
    """

    def __init__(self, input_c, output_c, stride):
        super().__init__()

        # Explicit exceptions instead of `assert`: assertions are removed
        # under `python -O`, which would let an odd `output_c` silently
        # build a block with one channel fewer than requested.
        if stride not in (1, 2):
            raise ValueError("stride must be 1 or 2, got {}".format(stride))
        self.stride = stride

        if output_c % 2 != 0:
            raise ValueError("output_c must be even, got {}".format(output_c))
        branch_features = output_c // 2

        # The stride-1 path concatenates an untouched half of the input with
        # the processed half, so the channel count cannot change.
        if self.stride == 1 and input_c != branch_features * 2:
            raise ValueError(
                "stride 1 requires input_c == output_c, got input_c={} and output_c={}".format(
                    input_c, output_c))

        if self.stride == 2:
            # Downsampling shortcut: depthwise 3x3 (stride 2) then pointwise 1x1.
            self.branch1 = nn.Sequential(
                depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.LeakyReLU(negative_slope=0.1, inplace=True))
        else:
            # Unused in forward() for stride 1; kept as an empty module.
            self.branch1 = nn.Sequential()

        # Main branch: pointwise 1x1 -> depthwise 3x3 -> pointwise 1x1.
        # It sees the whole input when downsampling, otherwise only the
        # second half of the channels.
        self.branch2 = nn.Sequential(
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.LeakyReLU(negative_slope=0.1, inplace=True))

    def forward(self, x):
        """Apply the unit to ``x`` and return the channel-shuffled result."""
        if self.stride == 1:
            # Split channels in half: x1 is the identity path, x2 is processed.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        # Mix channels coming from the two branches.
        out = channel_shuffle(out, 2)
        return out


class ShuffleNet(nn.Module):
    """Convolutional backbone built from ``InvertedResidual`` units.

    Layout: a stride-2 3x3 convolution stem (3 -> 24 channels), a stride-2
    max-pool, then five stages registered as ``stage2`` .. ``stage6``. Each
    stage is a stride-2 ``InvertedResidual`` followed by a stride-1 one, and
    the stage output widths are 48, 96, 192, 384 and 768 channels.
    """

    def __init__(self):
        super().__init__()
        stage_out_channels = [48, 96, 192, 384, 768]
        image_channels = 3
        stem_channels = 24

        # Stem: 3x3 conv with stride 2, batch norm, LeakyReLU.
        self.conv1 = nn.Sequential(
            nn.Conv2d(image_channels, stem_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_channels),
            nn.LeakyReLU(negative_slope=0.1, inplace=True))
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages are attached by name so forward() can reference
        # self.stage2 .. self.stage6 directly.
        stage_names = ["stage{}".format(i) for i in [2, 3, 4, 5, 6]]
        in_channels = stem_channels
        for name, out_channels in zip(stage_names, stage_out_channels):
            # First unit downsamples and changes width; second keeps both.
            blocks = [InvertedResidual(in_channels, out_channels, 2),
                      InvertedResidual(out_channels, out_channels, 1)]
            setattr(self, name, nn.Sequential(*blocks))
            in_channels = out_channels

    def forward(self, x):
        """Run the backbone and return the output of the last stage.

        The shape comments below are for an example input of Nx3x512x640;
        every stride-2 step halves the spatial size.
        """
        # x: Nx3x512x640
        x = self.conv1(x)  # Nx24x256x320
        x = self.maxpool(x)  # Nx24x128x160
        x = self.stage2(x)  # Nx48x64x80
        x = self.stage3(x)  # Nx96x32x40
        x = self.stage4(x)  # Nx192x16x20
        x = self.stage5(x)  # Nx384x8x10
        x = self.stage6(x)  # Nx768x4x5
        return x
